{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f445c1d1-acb9-431e-a7ff-50c41f064359",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "None of PyTorch, TensorFlow >= 2.0, or Flax have been found. Models won't be available and only tokenizers, configuration and file/data utilities can be used.\n",
      "[nltk_data] Downloading package stopwords to\n",
      "[nltk_data]     /Users/jerryliu/nltk_data...\n",
      "[nltk_data]   Package stopwords is already up-to-date!\n"
     ]
    }
   ],
   "source": [
    "from utils import get_train_str, get_train_and_eval_data, get_eval_preds, train_prompt\n",
    "\n",
    "import warnings\n",
    "\n",
    "warnings.filterwarnings(\"ignore\")\n",
    "warnings.simplefilter(\"ignore\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cf3cbd90-d5e1-4c30-a3bc-8b39fbd85d70",
   "metadata": {},
   "outputs": [],
   "source": [
    "# load up the titanic data\n",
    "train_df, train_labels, eval_df, eval_labels = get_train_and_eval_data(\"data/train.csv\")"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "id": "fa2634f9-cb33-4f1e-81f9-3a3b285e2580",
   "metadata": {},
   "source": [
    "## Few-shot Prompting with GPT-3 for Titanic Dataset\n",
    "In this section, we can show how we can prompt GPT-3 on its own (without using GPT Index) to attain ~80% accuracy on Titanic! \n",
    "\n",
    "We can do this by simply providing a few example inputs. Or we can simply provide no example inputs at all (zero-shot). Both achieve the same results."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d0698fd2-1361-49ae-8c17-8124e9b932a4",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "The following structured data is provided in \"Feature Name\":\"Feature Value\" format.\n",
      "Each datapoint describes a passenger on the Titanic.\n",
      "The task is to decide whether the passenger survived.\n",
      "Some example datapoints are given below: \n",
      "-------------------\n",
      "{train_str}\n",
      "-------------------\n",
      "Given this, predict whether the following passenger survived. Return answer as a number between 0 or 1. \n",
      "{eval_str}\n",
      "Survived: \n"
     ]
    }
   ],
   "source": [
    "# first demonstrate the prompt template\n",
    "print(train_prompt.template)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4b39e2e7-be07-42f8-a27a-3419e84cfb2c",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Example datapoints in `train_str`: \n",
      "This is the Data:\n",
      "Age:28.0\n",
      "Embarked:S\n",
      "Fare:7.8958\n",
      "Parch:0\n",
      "Pclass:3\n",
      "Sex:male\n",
      "SibSp:0\n",
      "This is the correct answer:\n",
      "Survived: 0\n",
      "\n",
      "This is the Data:\n",
      "Age:17.0\n",
      "Embarked:S\n",
      "Fare:7.925\n",
      "Parch:2\n",
      "Pclass:3\n",
      "Sex:female\n",
      "SibSp:4\n",
      "This is the correct answer:\n",
      "Survived: 1\n",
      "\n",
      "This is the Data:\n",
      "Age:30.0\n",
      "Embarked:S\n",
      "Fare:16.1\n",
      "Parch:0\n",
      "Pclass:3\n",
      "Sex:male\n",
      "SibSp:1\n",
      "This is the correct answer:\n",
      "Survived: 0\n",
      "\n",
      "This is the Data:\n",
      "Age:22.0\n",
      "Embarked:S\n",
      "Fare:7.25\n",
      "Parch:0\n",
      "Pclass:3\n",
      "Sex:male\n",
      "SibSp:0\n",
      "This is the correct answer:\n",
      "Survived: 0\n",
      "\n",
      "This is the Data:\n",
      "Age:45.0\n",
      "Embarked:S\n",
      "Fare:13.5\n",
      "Parch:0\n",
      "Pclass:2\n",
      "Sex:female\n",
      "SibSp:0\n",
      "This is the correct answer:\n",
      "Survived: 1\n",
      "\n",
      "This is the Data:\n",
      "Age:25.0\n",
      "Embarked:S\n",
      "Fare:0.0\n",
      "Parch:0\n",
      "Pclass:3\n",
      "Sex:male\n",
      "SibSp:0\n",
      "This is the correct answer:\n",
      "Survived: 1\n",
      "\n",
      "This is the Data:\n",
      "Age:18.0\n",
      "Embarked:S\n",
      "Fare:20.2125\n",
      "Parch:1\n",
      "Pclass:3\n",
      "Sex:male\n",
      "SibSp:1\n",
      "This is the correct answer:\n",
      "Survived: 0\n",
      "\n",
      "This is the Data:\n",
      "Age:33.0\n",
      "Embarked:S\n",
      "Fare:9.5\n",
      "Parch:0\n",
      "Pclass:3\n",
      "Sex:male\n",
      "SibSp:0\n",
      "This is the correct answer:\n",
      "Survived: 0\n",
      "\n",
      "This is the Data:\n",
      "Age:24.0\n",
      "Embarked:S\n",
      "Fare:65.0\n",
      "Parch:2\n",
      "Pclass:2\n",
      "Sex:female\n",
      "SibSp:1\n",
      "This is the correct answer:\n",
      "Survived: 1\n",
      "\n",
      "This is the Data:\n",
      "Age:26.0\n",
      "Embarked:S\n",
      "Fare:7.925\n",
      "Parch:0\n",
      "Pclass:3\n",
      "Sex:female\n",
      "SibSp:0\n",
      "This is the correct answer:\n",
      "Survived: 1\n"
     ]
    }
   ],
   "source": [
    "# Get \"training\" prompt string\n",
    "train_n = 10\n",
    "eval_n = 40\n",
    "train_str = get_train_str(train_df, train_labels, train_n=train_n)\n",
    "print(f\"Example datapoints in `train_str`: \\n{train_str}\")"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "id": "819a06f7-3171-4edb-b90c-0a3eae308a04",
   "metadata": {},
   "source": [
    "#### Do evaluation with the training prompt string"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4a7f2202-518c-41a3-80ab-1e98bbcca903",
   "metadata": {},
   "outputs": [],
   "source": [
    "from sklearn.metrics import accuracy_score\n",
    "import numpy as np\n",
    "\n",
    "eval_preds = get_eval_preds(train_prompt, train_str, eval_df, n=eval_n)\n",
    "eval_label_chunk = eval_labels[:eval_n]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "64323a4d-6eea-4e40-9eac-b2deed60192b",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "ACCURACY: 0.8\n"
     ]
    }
   ],
   "source": [
    "acc = accuracy_score(eval_label_chunk, np.array(eval_preds).round())\n",
    "print(f\"ACCURACY: {acc}\")"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "id": "11790d28-8f34-42dd-b11f-6aad21fd5f46",
   "metadata": {},
   "source": [
    "#### Do evaluation with no training prompt string! "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "aaf993e5-c363-4f18-a28f-09761e49cb6d",
   "metadata": {},
   "outputs": [],
   "source": [
    "from sklearn.metrics import accuracy_score\n",
    "import numpy as np\n",
    "\n",
    "eval_preds_null = get_eval_preds(train_prompt, \"\", eval_df, n=eval_n)\n",
    "eval_label_chunk = eval_labels[:eval_n]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c3b8bcd5-5972-4ce5-9aa1-57460cdde199",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "ACCURACY: 0.8\n"
     ]
    }
   ],
   "source": [
    "acc_null = accuracy_score(eval_label_chunk, np.array(eval_preds_null).round())\n",
    "print(f\"ACCURACY: {acc_null}\")"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "id": "8f0a5e4b-e627-4b47-a807-939813596594",
   "metadata": {},
   "source": [
    "## Extending with Summary Index"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "id": "42a1ca28-96e9-4cd2-bd48-0673917ad057",
   "metadata": {},
   "source": [
    "#### Build Index"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6c59b030-855d-4e27-89c3-74c972d1bf19",
   "metadata": {},
   "outputs": [],
   "source": [
    "from llama_index import SummaryIndex\n",
    "from llama_index.schema import Document"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8f9556de-e323-4318-bb71-cff75bf8c3c1",
   "metadata": {},
   "outputs": [],
   "source": [
    "index = SummaryIndex([])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e27720fc-af36-40fd-8c55-41485248aa9f",
   "metadata": {},
   "outputs": [],
   "source": [
    "# insertion into index\n",
    "batch_size = 40\n",
    "num_train_chunks = 5\n",
    "\n",
    "for i in range(num_train_chunks):\n",
    "    print(f\"Inserting chunk: {i}/{num_train_chunks}\")\n",
    "    start_idx = i * batch_size\n",
    "    end_idx = (i + 1) * batch_size\n",
    "    train_batch = train_df.iloc[start_idx : end_idx + batch_size]\n",
    "    labels_batch = train_labels.iloc[start_idx : end_idx + batch_size]\n",
    "    all_train_str = get_train_str(train_batch, labels_batch, train_n=batch_size)\n",
    "    index.insert(Document(text=all_train_str))"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "id": "e78db088-6649-44db-b52a-766316713b96",
   "metadata": {},
   "source": [
    "#### Query Index"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9cb90564-1de2-412f-8318-d5280855004e",
   "metadata": {},
   "outputs": [],
   "source": [
    "from utils import query_str, qa_data_prompt, refine_prompt"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "77c1ae36-e0af-47bc-a656-4971af699755",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'Which is the relationship between these features and predicting survival?'"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "query_str"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c403710f-d4b3-4287-94f5-e275ea19b476",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "> Starting query: Which is the relationship between these features and predicting survival?\n"
     ]
    }
   ],
   "source": [
    "query_engine = index.as_query_engine(\n",
    "    text_qa_template=qa_data_prompt,\n",
    "    refine_template=refine_prompt,\n",
    ")\n",
    "response = query_engine.query(\n",
    "    query_str,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d2545ab1-980a-4fbd-8add-7ef957801644",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "\n",
      "There is no definitive answer to this question, as the relationship between the features and predicting survival will vary depending on the data. However, some possible relationships include: age (younger passengers are more likely to survive), sex (females are more likely to survive), fare (passengers who paid more for their ticket are more likely to survive), and pclass (passengers in first or second class are more likely to survive).\n"
     ]
    }
   ],
   "source": [
    "print(response)"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "id": "d0d7d260-2283-49f6-ac40-35c7071cc54d",
   "metadata": {},
   "source": [
    "#### Get Predictions and Evaluate"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e7b98057-957c-48ef-be85-59ff9813d201",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "The following structured data is provided in \"Feature Name\":\"Feature Value\" format.\n",
      "Each datapoint describes a passenger on the Titanic.\n",
      "The task is to decide whether the passenger survived.\n",
      "We discovered the following relationship between features and survival:\n",
      "-------------------\n",
      "{train_str}\n",
      "-------------------\n",
      "Given this, predict whether the following passenger survived. \n",
      "Return answer as a number between 0 or 1. \n",
      "{eval_str}\n",
      "Survived: \n",
      "\n",
      "\n",
      "`train_str`: \n",
      "\n",
      "There is no definitive answer to this question, as the relationship between the features and predicting survival will vary depending on the data. However, some possible relationships include: age (younger passengers are more likely to survive), sex (females are more likely to survive), fare (passengers who paid more for their ticket are more likely to survive), and pclass (passengers in first or second class are more likely to survive).\n"
     ]
    }
   ],
   "source": [
    "# get eval preds\n",
    "from utils import train_prompt_with_context\n",
    "\n",
    "train_str = response\n",
    "print(train_prompt_with_context.template)\n",
    "print(f\"\\n\\n`train_str`: {train_str}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "659c6a3f-1c5d-4314-87dc-908e76d50e4a",
   "metadata": {},
   "outputs": [],
   "source": [
    "# do evaluation\n",
    "from sklearn.metrics import accuracy_score\n",
    "import numpy as np\n",
    "\n",
    "eval_n = 40\n",
    "eval_preds = get_eval_preds(train_prompt_with_context, train_str, eval_df, n=eval_n)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7424e7d3-2576-42bc-b626-cf8088265004",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "ACCURACY: 0.85\n"
     ]
    }
   ],
   "source": [
    "eval_label_chunk = eval_labels[:eval_n]\n",
    "acc = accuracy_score(eval_label_chunk, np.array(eval_preds).round())\n",
    "print(f\"ACCURACY: {acc}\")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "gpt_retrieve_venv",
   "language": "python",
   "name": "gpt_retrieve_venv"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
